%% FUNCTION Least_TGL
% L21 Joint Feature Learning with Least Squares Loss.
%
%% OBJECTIVE
% argmin_{W,C} {  (0.5 * norm (Y{2} - X{2}' * W(:, 2))^2)+(0.5 * norm (Y{1} - X{1}' * W(:, 1)-C)^2)
%            + opts.rho_L2 * \|W\|_F^2 + rho1 * \|W\|_{2,1} }
%
%% INPUT
% X: {n * d} * t - input matrix
% Y: {n * 1} * t - output matrix
% rho1: L2,1-norm group Lasso parameter.
% optional:
%   opts.rho_L2: L2-norm parameter (default = 0).
%
%% OUTPUT
% W: model: d * t
% C: constant offset subtracted from the residual of task 1 only (see OBJECTIVE).
% funcVal: function value vector.
%
%% LICENSE
%   This program is free software: you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation, either version 3 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You should have received a copy of the GNU General Public License
%   along with this program.  If not, see <http://www.gnu.org/licenses/>.
%
%   Copyright (C) 2011 - 2012 Jiayu Zhou and Jieping Ye
%
%   We suggest reading the Manual first.
%   For any problems, please contact Jiayu Zhou at jiayu.zhou@asu.edu
%
%   Last modified on June 3, 2012.
%
%% RELATED PAPERS
%
%   [1] Evgeniou, A. and Pontil, M. Multi-task feature learning, NIPS 2007.
%   [2] Liu, J. and Ye, J. Efficient L1/Lq Norm Regularization, Technical
%       Report, 2010.
%
%% RELATED FUNCTIONS
%  Least_L21, init_opts

%% Code starts here
function [W, C, funcVal] = least_L21(X, Y, rho1, opts)
% Two-task least-squares regression with an L2,1 row-sparsity penalty on W
% and a constant offset C that enters the residual of task 1 only.
%
% The solver is an accelerated proximal-gradient loop with a backtracking
% line search on the step parameter gamma (step size 1/gamma).
%
% Inputs
%   X    : cell array of task data; passed through multi_transpose, after
%          which X{t} is used as a (dimension x samples) matrix.
%   Y    : cell array of task responses, Y{t} multiplied as X{t} * Y{t}.
%   rho1 : weight of the L2,1 penalty.
%   opts : optional struct, completed by init_opts. Fields read here:
%            rho_L2  - weight of the squared Frobenius penalty (default 0)
%            C0      - starting value of C (default 0)
%            W0      - starting W, must be dimension x 2
%            init    - 2: start from zeros; otherwise start from X{t}*Y{t}
%                      (only consulted when W0 is not supplied)
%            maxIter, tol, tFlag - iteration cap and stopping rule
%
% Outputs
%   W       : dimension x 2 weight matrix.
%   C       : offset for task 1.
%   funcVal : column vector of objective values, one entry per iteration.

if nargin < 3
    error('\n Inputs: X, Y, rho1, should be specified!\n');
end
X = multi_transpose(X);

if nargin < 4
    opts = [];
end

% Fill in defaults for the generic solver options.
opts = init_opts(opts);

if isfield(opts, 'rho_L2')
    rho_L2 = opts.rho_L2;
else
    rho_L2 = 0;
end

if isfield(opts, 'C0')
    C0 = opts.C0;
else
    C0 = 0;
end

task_num  = 2;               % this variant is hard-wired to two tasks
dimension = size(X{1}, 1);
funcVal   = [];

% Choose the starting point. A user-supplied W0 takes precedence over opts.init.
if isfield(opts, 'W0')
    W0 = opts.W0;
    if (nnz(size(W0) - [dimension, task_num]))
        error('\n Check the input .W0');
    end
elseif opts.init == 2
    W0 = zeros(dimension, task_num);
else
    % opts.init == 0, and also the fallback for any other init value when no
    % W0 was supplied (previously W0 was left undefined in that case).
    W0 = [];
    for t_idx = 1:task_num
        W0 = cat(2, W0, X{t_idx} * Y{t_idx});
    end
end

bFlag = 0; % set when the proximal step barely moves the iterate

Wz     = W0;
Cz     = C0;
Wz_old = W0;
Cz_old = C0;

t     = 1;
t_old = 0;

iter      = 0;
gamma     = 1;
gamma_inc = 2;

while iter < opts.maxIter
    % Momentum coefficient and extrapolated search point.
    alpha = (t_old - 1) / t;
    Ws = (1 + alpha) * Wz - alpha * Wz_old;
    Cs = (1 + alpha) * Cz - alpha * Cz_old;

    % Smooth-part value and gradient at the search point.
    [gWs, gCs] = gradVal_eval(Ws, Cs);
    Fs         = funVal_eval(Ws, Cs);

    % Backtracking: grow gamma until the quadratic model at (Ws, Cs)
    % upper-bounds the smooth objective at the candidate point.
    while true
        Wzp = FGLasso_projection(Ws - gWs / gamma, rho1 / gamma);
        Czp = Cs - gCs / gamma;   % C is unpenalized: plain gradient step
        Fzp = funVal_eval(Wzp, Czp);

        delta_Wzp = Wzp - Ws;
        delta_Czp = Czp - Cs;
        nrm_delta_Wzp = norm(delta_Wzp, 'fro')^2;
        nrm_delta_Czp = norm(delta_Czp, 'fro')^2;
        r_sum = (nrm_delta_Wzp + nrm_delta_Czp) / 2;

        Fzp_gamma = Fs + sum(sum(delta_Wzp .* gWs)) ...
            + sum(sum(delta_Czp .* gCs)) ...
            + gamma / 2 * nrm_delta_Wzp ...
            + gamma / 2 * nrm_delta_Czp;

        if (r_sum <= 1e-20)
            bFlag = 1; % the step makes essentially no change
            break;
        end

        if (Fzp <= Fzp_gamma)
            break;
        else
            gamma = gamma * gamma_inc;
        end
    end

    Wz_old = Wz;
    Cz_old = Cz;
    Wz = Wzp;
    Cz = Czp;

    % Full objective = smooth part + L2,1 penalty.
    funcVal = cat(1, funcVal, Fzp + nonsmooth_eval(Wz, rho1));

    if (bFlag)
        break;
    end

    % Stopping rule selected by opts.tFlag.
    switch (opts.tFlag)
        case 0 % absolute change in objective
            if iter >= 2
                if (abs(funcVal(end) - funcVal(end-1)) <= opts.tol)
                    break;
                end
            end
        case 1 % change relative to the previous objective value
            if iter >= 2
                if (abs(funcVal(end) - funcVal(end-1)) <= ...
                        opts.tol * funcVal(end-1))
                    break;
                end
            end
        case 2 % objective value itself below tolerance
            if (funcVal(end) <= opts.tol)
                break;
            end
        case 3 % run until the iteration cap
            if iter >= opts.maxIter
                break;
            end
    end

    iter  = iter + 1;
    t_old = t;
    t     = 0.5 * (1 + (1 + 4 * t^2)^0.5);
end

% Wz/Cz equal the last accepted Wzp/Czp at every loop exit, and remain the
% starting point if the loop body never ran.
W = Wz;
C = Cz;

% private functions

    function [X] = FGLasso_projection (D, lambda )
        % Row-wise shrinkage: each row of D is scaled by
        % max(0, 1 - lambda / ||row||_2).
        X = repmat(max(0, 1 - lambda ./ sqrt(sum(D.^2, 2))), 1, size(D, 2)) .* D;
    end

    function [grad_W, grad_C] = gradVal_eval(W, C)
        % Gradient of the smooth part with respect to W and C.
        % The parallel and serial branches of the original computed the same
        % quantity (the parallel one with invalid indexing), so a single
        % loop is used regardless of opts.pFlag.
        grad_W = zeros(size(W));
        for i = 1:task_num
            resid = X{i}' * W(:, i) - Y{i};
            if i == 1
                resid = resid + C; % only task 1 carries the offset
            end
            grad_W(:, i) = X{i} * resid;
        end
        grad_W = grad_W + rho_L2 * 2 * W;
        % d/dC of 0.5 * ||Y{1} - X{1}'*W(:,1) - C||^2 is the sum of the
        % task-1 residual (X{1}'*W(:,1) - Y{1} + C), with a positive sign.
        grad_C = sum(X{1}' * W(:, 1) - Y{1} + C);
    end

    function [fval] = funVal_eval (W, C)
        % Smooth part of the objective: squared-error losses of both tasks
        % plus the rho_L2-weighted squared Frobenius norm of W.
        fval = 0.5 * norm(Y{1} - X{1}' * W(:, 1) - C)^2;
        for i = 2:task_num
            fval = fval + 0.5 * norm(Y{i} - X{i}' * W(:, i))^2;
        end
        fval = fval + rho_L2 * norm(W, 'fro')^2;
    end

    function [non_smooth_value] = nonsmooth_eval(W, rho_1)
        % L2,1 penalty: rho_1 times the sum of the Euclidean norms of the
        % rows of W.
        non_smooth_value = rho_1 * sum(sqrt(sum(W.^2, 2)));
    end
end